{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "32f3796f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torch.autograd.grad_mode.set_grad_enabled at 0x7f8f42d45ca0>"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import os\n",
    "import json\n",
    "from collections import defaultdict\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from tqdm import tqdm\n",
    "\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from transformer_lens import (\n",
    "    HookedTransformer,\n",
    ")\n",
    "from toxicity.figures.fig_utils import load_hooked, get_svd\n",
    "from constants import MODEL_DIR, DATA_DIR\n",
    "\n",
    "torch.set_grad_enabled(False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "43d04809",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loaded pretrained model gpt2-medium into HookedTransformer\n",
      "Loaded pretrained model gpt2-medium into HookedTransformer\n"
     ]
    }
   ],
   "source": [
    "\n",
    "model = load_hooked(\n",
    "    \"gpt2-medium\",\n",
    "    os.path.join(MODEL_DIR, \"dpo.pt\"),\n",
    ")\n",
    "gpt2 = HookedTransformer.from_pretrained(\"gpt2-medium\")\n",
    "gpt2.tokenizer.padding_side = \"left\"\n",
    "gpt2.tokenizer.pad_token_id = gpt2.tokenizer.eos_token_id\n",
    "\n",
    "toxic_vector = torch.load(os.path.join(MODEL_DIR, \"probe.pt\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "d84999a3",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "with open(\n",
    "    os.path.join(DATA_DIR, \"intervene_data/challenge_prompts.jsonl\"), \"r\"\n",
    ") as file_p:\n",
    "    data = file_p.readlines()\n",
    "\n",
    "prompts = [json.loads(x.strip())[\"prompt\"] for x in data]\n",
    "tokenized_prompts = model.to_tokens(prompts, prepend_bos=True).cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "16fc0a84",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "\n",
    "_, scores_gpt2 = get_svd(gpt2, toxic_vector, 128)\n",
    "\n",
    "mlps_by_layer = {}\n",
    "for _score_obj in scores_gpt2:\n",
    "    layer = _score_obj[2]\n",
    "    if layer not in mlps_by_layer:\n",
    "        mlps_by_layer[layer] = []\n",
    "    mlps_by_layer[layer].append(_score_obj[1])\n",
    "\n",
    "vectors_of_interest = [\n",
    "    (_score_obj[2], _score_obj[1], _score_obj[0])\n",
    "    for _score_obj in scores_gpt2[:64]\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6ef1da2d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Grabbing mlp mids...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 99%|███████████████████████████████████████████████████████████████████████████████████████████ | 297/300 [12:57<00:08,  2.73s/it]"
     ]
    }
   ],
   "source": [
    "\n",
    "\n",
    "gpt2_acts_of_interest = defaultdict(list)\n",
    "dpo_acts_of_interest = defaultdict(list)\n",
    "sample_size = tokenized_prompts.shape[0]\n",
    "batch_size = 4\n",
    "print(\"Grabbing mlp mids...\")\n",
    "for idx in tqdm(range(0, sample_size, batch_size)):\n",
    "    batch = tokenized_prompts[idx : idx + batch_size, :]\n",
    "    dpo_batch = batch.clone()\n",
    "\n",
    "    for timestep in range(20):\n",
    "        with torch.inference_mode():\n",
    "            _, cache = gpt2.run_with_cache(batch)\n",
    "\n",
    "        sampled = gpt2.unembed(cache[\"ln_final.hook_normalized\"]).argmax(-1)[\n",
    "            :, -1\n",
    "        ]\n",
    "        for _vec in vectors_of_interest:\n",
    "            _layer = _vec[0]\n",
    "            _idx = _vec[1]\n",
    "            mlp_mid = cache[f\"blocks.{_layer}.mlp.hook_post\"][:, -1, _idx]\n",
    "            gpt2_acts_of_interest[(_layer, _idx)].extend(mlp_mid.tolist())\n",
    "\n",
    "        with torch.inference_mode():\n",
    "            _, cache = model.run_with_cache(dpo_batch)\n",
    "        sampled = model.unembed(cache[\"ln_final.hook_normalized\"]).argmax(-1)[\n",
    "            :, -1\n",
    "        ]\n",
    "\n",
    "        for _vec in vectors_of_interest:\n",
    "            _layer = _vec[0]\n",
    "            _idx = _vec[1]\n",
    "            mlp_mid = cache[f\"blocks.{_layer}.mlp.hook_post\"][:, -1, _idx]\n",
    "            dpo_acts_of_interest[(_layer, _idx)].extend(mlp_mid.tolist())\n",
    "\n",
    "        batch = torch.concat([batch, sampled.unsqueeze(-1)], dim=-1)\n",
    "        dpo_batch = torch.concat([dpo_batch, sampled.unsqueeze(-1)], dim=-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e700b355",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "d_mlp = model.cfg.d_mlp\n",
    "dpo_acts_mean = {}\n",
    "gpt2_acts_mean = {}\n",
    "num_mlps = 5\n",
    "for _vec in vectors_of_interest[:num_mlps]:\n",
    "\n",
    "    _layer = _vec[0]\n",
    "    _idx = _vec[1]\n",
    "    gpt2_acts_mean[(_layer, _idx)] = np.mean(\n",
    "        gpt2_acts_of_interest[(_layer, _idx)]\n",
    "    )\n",
    "    dpo_acts_mean[(_layer, _idx)] = np.mean(\n",
    "        dpo_acts_of_interest[(_layer, _idx)]\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8a85e953",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "raw_data = []\n",
    "num_mlps = 5\n",
    "for _vec in vectors_of_interest[:num_mlps]:\n",
    "    _layer = _vec[0]\n",
    "    _idx = _vec[1]\n",
    "\n",
    "    raw_data.append(\n",
    "        {\n",
    "            \"MLP\": f\"L:{_layer}\\nIdx:{_idx}\",\n",
    "            \"Mean Activation\": dpo_acts_mean[(_layer, _idx)].item(),\n",
    "            \"Model\": \"DPO\",\n",
    "        }\n",
    "    )\n",
    "\n",
    "    raw_data.append(\n",
    "        {\n",
    "            \"MLP\": f\"L:{_layer}\\nIdx:{_idx}\",\n",
    "            \"Mean Activation\": gpt2_acts_mean[(_layer, _idx)].item(),\n",
    "            \"Model\": \"GPT2\",\n",
    "        }\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "be503524",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "\n",
    "data = pd.DataFrame(raw_data)\n",
    "sns.set_theme(context=\"paper\", style=\"ticks\", rc={\"lines.linewidth\": 1})\n",
    "\n",
    "sns.catplot(\n",
    "    data=data,\n",
    "    x=\"MLP\",\n",
    "    y=\"Mean Activation\",\n",
    "    hue=\"Model\",\n",
    "    hue_order=[\"GPT2\", \"DPO\"],\n",
    "    height=2,\n",
    "    aspect=3.25 / 2,\n",
    "    kind=\"bar\",\n",
    "    legend_out=False,\n",
    ")\n",
    "\n",
    "\n",
    "plt.savefig(\"activation_drops.pdf\", bbox_inches=\"tight\", dpi=1200)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
